#pragma once

#include "SupervisedDimensionReduction.hpp"

namespace zzz {
template<zuint N, typename T>
class LDA : public SupervisedDimensionReduction<N,T>
{
public:
  /// Reduce each N-dimensional sample in `data` to an N2-dimensional one.
  /// @param data     input samples
  /// @param newdata  output container; cleared on entry
  /// NOTE(review): the projection is not implemented yet -- `newdata` is only
  /// cleared and is returned empty. A finished version would presumably project
  /// each sample onto the N2 leading eigenvectors stored in EigenVec; the
  /// row/column layout and eigenvalue ordering produced by EigenSym need to be
  /// confirmed before writing it.
  template<zuint N2>
  void DimensionReduce(const vector<Vector<N,T> > &data, vector<Vector<N2,T> > &newdata)
  {
    newdata.clear();
  }

private:
  /// Fit the LDA model.
  /// Builds the within-class scatter matrix Sw and the between-class scatter
  /// matrix Sb from labeled samples, then eigen-decomposes Sw^-1 * Sb into
  /// EigenVec / EigenVal.
  ///
  /// @param data   training samples. Each data[i] is added to a Vector<N,T>, so
  ///               it is presumably one N-dimensional sample (the parameter
  ///               type was originally vector<Vector<N,T> >). TODO confirm that
  ///               zMatrixBaseR<double> really exposes size()/operator[] with
  ///               these semantics.
  /// @param label  class index of each sample. Assumed to lie in
  ///               [0, label_count_); this is not checked here.
  ///
  /// INFO: THIS FUNCTION IS NOT FINISHED, NEED TO GO THROUGH AGAIN
  void DoTrain(const /*vector<Vector<N,T> >*/zMatrixBaseR<double> &data, const vector<int> &label)
  {
    const zuint sample_count = data.size();
    // Nothing to fit; this also avoids dividing by zero when averaging below.
    if (sample_count == 0) return;

    // label_count_ lives in the dependent base class, so it must be reached
    // through `this->`. Unqualified lookup does not search dependent bases.
    const zuint class_count = this->label_count_;

    // A single pass accumulates per-class sums, per-class counts and the sum
    // over all samples.
    vector<Vector<N,T> > means(class_count, Vector<N,T>(0));
    vector<int> count(class_count, 0);
    Vector<N,T> global_mean(0);
    for (zuint i=0; i<sample_count; i++) {
      means[label[i]] += data[i];
      count[label[i]] ++;
      global_mean += data[i];
    }
    // Turn the sums into means. A class with no samples keeps a zero mean
    // instead of dividing by zero, and it is skipped in Sb below.
    for (zuint i=0; i<class_count; i++) {
      if (count[i] > 0) means[i] /= count[i];
    }
    global_mean /= static_cast<int>(sample_count);

    // Within-class scatter: the sum over samples of (x - mu_c)(x - mu_c)^T.
    // The outer product is N x N (the feature dimension), not
    // class_count x class_count.
    zMatrix<double> Sw=Zeros(N, N);
    for (zuint i=0; i<sample_count; i++) {
      zVector<double> diff = Dress(data[i] - means[label[i]]);
      Sw += diff * Trans(diff);
    }

    // Between-class scatter: the sum over classes of n_c (mu_c - mu)(mu_c - mu)^T.
    // Here mu is the mean of ALL samples, which is consistent with the n_c
    // weighting. It is not the unweighted mean of the class means.
    zMatrix<double> Sb=Zeros(N, N);
    for (zuint i=0; i<class_count; i++) {
      if (count[i] == 0) continue;  // an empty class contributes nothing
      zVector<double> diff = Dress(means[i] - global_mean);
      Sb += diff * Trans(diff) * count[i];
    }

    // NOTE(review): Sw^-1 * Sb is in general NOT symmetric, so a symmetric
    // eigen-solver may return wrong results. A safer formulation is the
    // generalized symmetric-definite problem Sb v = lambda Sw v (e.g. via a
    // Cholesky factorization of Sw); the project's API for this needs to be
    // confirmed. Sw can also be singular when there are fewer samples than
    // dimensions.
    EigenVec=Invert(Sw) * Sb;
    ZCHECK(EigenSym(EigenVec, EigenVal));
  }
  zMatrix<double> EigenVec, EigenVal;
};
};  // namespace zzz